"""
Non-Overlap: Code to take in a set of raw detections and produce a set of non-overlapping detections from it.

Author: Jonathon Luiten
"""

import os
import sys
from multiprocessing.pool import Pool
from multiprocessing import freeze_support

# Put the directory two levels above this file (presumably the repository root
# that contains the `trackeval` package) on sys.path BEFORE importing trackeval.
# This lets the script run directly from a source checkout. An insert placed after
# the imports would have no effect on how they resolve.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

from trackeval.baselines import baseline_utils as butils  # noqa: E402
from trackeval.utils import get_code_path  # noqa: E402

code_path = get_code_path()

# Script configuration; edit these values directly to change what is processed.
config = {
    # Root folder of the raw per-sequence detection files; '{split}' is filled from SPLIT.
    "INPUT_FOL": os.path.join(
        code_path, "data/detections/rob_mots/{split}/raw_supplied/data/"
    ),
    # Root folder for results; each output path is the input path with
    # INPUT_FOL swapped for OUTPUT_FOL.
    "OUTPUT_FOL": os.path.join(
        code_path, "data/detections/rob_mots/{split}/non_overlap_supplied/data/"
    ),
    "SPLIT": "train",  # valid: 'train', 'val', 'test'.
    "Benchmarks": None,  # List of benchmark names; if None, all benchmarks in SPLIT.
    "Num_Parallel_Cores": None,  # Worker-process count; if None (or 0), run serially.
    # Mask-IoU threshold handed to butils.mask_NMS as `nms_threshold`.
    "THRESHOLD_NMS_MASK_IOU": 0.5,
}


def do_sequence(seq_file):
    """Convert one sequence file of raw detections into non-overlapping detections.

    The sequence is loaded and its classes are merged per timestep. Each
    timestep then goes through mask NMS (IoU threshold from
    config["THRESHOLD_NMS_MASK_IOU"]) followed by non-overlap. The surviving
    detections are written to the path obtained by replacing the formatted
    INPUT_FOL prefix of `seq_file` with the formatted OUTPUT_FOL prefix.

    Args:
        seq_file: path of the input sequence file.
    """
    # Class-separated layout: per_class['cls'][t] = {'ids', 'scores', 'im_hs', 'im_ws', 'mask_rles'}
    per_class = butils.load_seq(seq_file)

    # Class-combined layout: frames[t] = {'ids', 'scores', 'im_hs', 'im_ws', 'mask_rles', 'cls'}
    frames = butils.combine_classes(per_class)

    nms_iou = config["THRESHOLD_NMS_MASK_IOU"]
    rows = []

    for t, frame in enumerate(frames):
        # Drop redundant masks first, then resolve the remaining overlaps.
        frame = butils.mask_NMS(frame, nms_threshold=nms_iou)
        # already_sorted=True: assumes mask_NMS returns detections in the order
        # non_overlap expects (presumably by descending score) — TODO confirm.
        frame = butils.non_overlap(frame, already_sorted=True)

        # One output row per detection: [timestep ID class score im_h im_w mask_RLE]
        rows.extend(
            [
                t,
                int(det_id),
                frame["cls"][k],
                frame["scores"][k],
                frame["im_hs"][k],
                frame["im_ws"][k],
                frame["mask_rles"][k],
            ]
            for k, det_id in enumerate(frame["ids"])
        )

    # Mirror the input location under the output root.
    split = config["SPLIT"]
    src_root = config["INPUT_FOL"].format(split=split)
    dst_root = config["OUTPUT_FOL"].format(split=split)
    butils.write_seq(rows, seq_file.replace(src_root, dst_root))

    print("DONE:", seq_file)


if __name__ == "__main__":

    # Needed when the program is frozen into a Windows executable that uses
    # multiprocessing; a no-op otherwise.
    freeze_support()

    # Choose which benchmarks' sequences to process.
    if config["Benchmarks"]:
        benchmarks = config["Benchmarks"]
    else:
        benchmarks = [
            "davis_unsupervised",
            "kitti_mots",
            "youtube_vis",
            "ovis",
            "bdd_mots",
            "tao",
        ]
        # Only added for non-train splits (presumably these benchmarks have no
        # train-split detections here — confirm).
        if config["SPLIT"] != "train":
            benchmarks += ["waymo", "mots_challenge"]

    # Collect every sequence file under each benchmark folder. os.listdir returns
    # entries in arbitrary order, so sort them for a deterministic processing order.
    input_root = config["INPUT_FOL"].format(split=config["SPLIT"])
    seqs_todo = []
    for bench in benchmarks:
        bench_fol = os.path.join(input_root, bench)
        seqs_todo += [os.path.join(bench_fol, seq) for seq in sorted(os.listdir(bench_fol))]

    if config["Num_Parallel_Cores"]:
        # Run in parallel. do_sequence returns nothing, so results are discarded;
        # pool.map still re-raises any exception raised in a worker.
        with Pool(config["Num_Parallel_Cores"]) as pool:
            pool.map(do_sequence, seqs_todo)
    else:
        # Run in series.
        for seq_todo in seqs_todo:
            do_sequence(seq_todo)
